Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NAPSU-MQ inference configuration #55

Merged
merged 20 commits into from
Jul 7, 2023
Merged

NAPSU-MQ inference configuration #55

merged 20 commits into from
Jul 7, 2023

Conversation

oraisa
Copy link
Collaborator

@oraisa oraisa commented Jun 8, 2023

Allow changing the inference parameters for NAPSU-MQ. Also add an option to return the InferenceData object from NAPSU-MQ that allows inspecting MCMC diagnostics with arviz.

@oraisa oraisa marked this pull request as draft June 8, 2023 16:43
@oraisa oraisa marked this pull request as ready for review June 9, 2023 11:20
@oraisa oraisa requested a review from lumip June 9, 2023 11:20
@oraisa
Copy link
Collaborator Author

oraisa commented Jun 9, 2023

@lumip The way I used to configure NAPSU-MQ inference is very verbose when changing just one parameter, but it makes the signature of NapsuMQModel.fit simpler, as the config is just one argument. Do you think it is too verbose, in which case I could move some parts of the config to be method arguments.

Example from the notebook:

model = NapsuMQModel(required_marginals=required_marginals)
inference_config = NapsuMQInferenceConfig(
    mcmc_config=NapsuMQMCMCConfig(
        num_samples=1000
    )
)
result = model.fit(
    data=orig_df,
    rng=inference_rng,
    epsilon=1,
    delta=(n ** (-2)),
    inference_config=inference_config,
)

Copy link
Member

@lumip lumip left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think overall these are good changes. I've added some thoughts about possible improvements in comments directly at the relevant code.

twinify/napsu_mq/napsu_mq.py Outdated Show resolved Hide resolved
query_sets: Optional[Iterable] = None,
inference_config: NapsuMQInferenceConfig = NapsuMQInferenceConfig(),
show_progress: bool = True,
return_diagnostics: bool = False) -> 'NapsuMQResult':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to keep in mind to keep the signatures of fit compatible between this and DPVIModel so that they could be used somewhat interchangeable in code. So I think for now it would be good to keep **kwargs to absorb and ignore any unknown arguments here. (And long term we should think if there's a way to make downstream code truly agnostic about the method it is passed for inference.)
From that perspective, maybe NapsuMQModel should keep the inference config as a lifetime variable, i.e., it's passed during initialization instead of here..
Do we expect use cases where we'd like to change the inference configuration often for otherwise the same model?

Copy link
Collaborator Author

@oraisa oraisa Jun 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think using **kwargs to just absorb unknown arguments is a good idea, as that can silently ignore errors, for example misspelled argument names.

The config has a default value, so users can change from DPVI to NAPSU-MQ with the same fit call if they want to use the default config. If they are using a custom NAPSU-MQ config, and want to change to DPVI, they should be required to change the fit call, as the NAPSU-MQ config won't do anything with DPVI.

I changed the show_progressargument to silent for compatibility with DPVI.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can get behind not using **kwargs but then we need to sort out the differences in API between the two methods. With the changes here, NAPSU-MQ would have the following extraneous arguments that are not present for DPVI (or the base class InferenceModel):

  • query_sets
  • inference_config
  • return_diagnostics

We could make return_diagnostics a common argument, that could be handy for DPVI as well - alternatively, we could maybe include the diagnostics as part of the result object.

We could do similarly for inference_config, but I prefer making it an argument for NapsuMQModel.__init__, i.e., making a NapsuMQModel instance completely encapsulate all details of its inference process (DPVIModel already mostly functions that way).

What is actually the intended difference between query_sets here and required_marginals passed in the __init__? Could the be unified or both lifted into __init__.

Finally, I have reverted your change for show_progress and adapted DPVIModel accordingly (also removed the verbose argument there) - I think show_progress is the better/more descriptive argument name here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference between query_setsand required_marginals is that one chooses the queries explicitly and the other specifies queries that are always included when other queries are selected automatically. I renamed them to make this clear, and put both to __init__.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also moved inference_configto __init__.

twinify/napsu_mq/napsu_mq.py Outdated Show resolved Hide resolved
@oraisa oraisa requested a review from lumip June 15, 2023 16:44
Copy link
Member

@lumip lumip left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now looks good to me, thanks. I have made some small additional changes to InferenceModel and DPVIModel so that all fit function now share the same arguments.

I'll do a rebase and some cleanup and then merge it.

@lumip lumip force-pushed the napsu_mq_inference_config branch from 8ba9ff7 to 90b51a0 Compare July 6, 2023 15:54
@lumip lumip merged commit fb1f65c into master Jul 7, 2023
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants